# Copyright 1999-2020 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import logging
import os
import re

from bokeh.application import Application
from bokeh.application.handlers import FunctionHandler
from bokeh.embed import server_document
from bokeh.server.server import Server
from tornado import web

from ... import oscar as mo
from ...utils import get_next_port

logger = logging.getLogger(__name__)
# Marker prefixed to URLs handed to bokeh's ``server_document``; it is later
# substituted with a path relative to the current request
# (see ``MarsRequestHandler.bokeh_server_document``).
_ROOT_PLACEHOLDER = 'ROOT_PLACEHOLDER'


class BokehStaticFileHandler(web.StaticFileHandler):  # pragma: no cover
    """Static file handler that redirects bokeh assets to the bokeh package.

    Files whose final path component contains ``bokeh`` are resolved against
    the ``static`` directory shipped next to ``bokeh.server``. Every other
    file is resolved against the root configured for this handler.
    """

    @staticmethod
    def _get_path_root(root, path):
        """Return the directory that ``path`` should be resolved against."""
        from bokeh import server

        file_name = path.rpartition('/')[2]
        if 'bokeh' not in file_name:
            return root
        return os.path.join(os.path.dirname(server.__file__), "static")

    @classmethod
    def get_absolute_path(cls, root, path):
        resolved_root = cls._get_path_root(root, path)
        return super().get_absolute_path(resolved_root, path)

    def validate_absolute_path(self, root, absolute_path):
        # The root must be swapped here as well, otherwise tornado would
        # reject bokeh files as lying outside the configured root.
        resolved_root = self._get_path_root(root, absolute_path)
        return super().validate_absolute_path(resolved_root, absolute_path)


class MarsRequestHandler(web.RequestHandler):  # pragma: no cover
    """Base tornado request handler for Mars web pages.

    Stores the ``supervisor_addr`` given to ``initialize`` and offers helpers
    that render templates and embed bokeh documents using URLs relative to
    the current request path.
    """

    def initialize(self, supervisor_addr):
        self._supervisor_addr = supervisor_addr

    def get_root_path(self):
        """Return a relative prefix leading from the request path to the root.

        Runs of slashes are collapsed and leading slashes stripped; every
        remaining ``/`` contributes one ``../``. When none remain, ``./`` is
        returned.
        """
        normalized = re.sub('/+', '/', self.request.path).lstrip('/')
        depth = normalized.count('/')
        return '../' * depth if depth else './'

    def write_rendered(self, template, **kwargs):
        """Render ``template`` with ``request`` and ``root_path`` and write it."""
        rendered = template.render(
            request=self.request, root_path=self.get_root_path(), **kwargs)
        self.write(rendered)

    def bokeh_server_document(self, url, resources="default", arguments=None):
        """Return bokeh's ``server_document`` script with request-relative URLs.

        The script is generated against a placeholder root, which is then
        replaced by the relative root path of the current request.
        """
        raw_script = server_document(
            f'{_ROOT_PLACEHOLDER}/{url}', relative_urls=True,
            resources=resources, arguments=arguments)

        # FIXME the replacements below patch the script generated by bokeh
        #  so that its websocket connection works in proxy-passed environments.
        # Reuse whatever quote character bokeh placed before the placeholder.
        placeholder_pos = raw_script.index(_ROOT_PLACEHOLDER)
        quote = raw_script[placeholder_pos - 1]
        app_path_expr = (
            f'&bokeh-app-path=" + window.location.pathname.match(/.*\\//) + "{self.get_root_path()}'
        ).replace('"', quote)

        app_path_token = f'&bokeh-app-path=/{_ROOT_PLACEHOLDER}/'
        root_token = f'{_ROOT_PLACEHOLDER}/'
        patched = raw_script.replace(app_path_token, app_path_expr)
        return patched.replace(root_token, self.get_root_path())


class WebActor(mo.Actor):
    """Actor hosting the Mars web UI on a bokeh ``Server``.

    ``config`` is a mapping; :meth:`start` reads the keys ``host``, ``port``,
    ``bokeh_apps`` (path -> bokeh app function) and ``web_handlers``
    (URL pattern -> tornado handler class).
    """

    def __init__(self, config):
        super().__init__()
        self._config = config
        self._web_server = None

    async def start(self):
        """Create and start the web server.

        Bokeh app functions and extra tornado handlers both receive this
        actor's address as ``supervisor_addr``.

        If ``port`` is configured, it is used as-is and any ``OSError`` while
        binding is raised immediately. Otherwise a port is chosen with
        ``get_next_port`` on every attempt. Binding failures are retried up
        to 5 attempts in total before the last ``OSError`` is re-raised.
        """
        static_path = os.path.join(os.path.dirname(__file__), 'static')
        supervisor_addr = self.address

        host = self._config.get('host') or '0.0.0.0'
        configured_port = self._config.get('port')
        bokeh_apps = self._config.get('bokeh_apps', {})
        web_handlers = self._config.get('web_handlers', {})

        handlers = {
            p: Application(FunctionHandler(
                functools.partial(h, supervisor_addr=supervisor_addr)))
            for p, h in bokeh_apps.items()
        }

        handler_kwargs = {'supervisor_addr': supervisor_addr}
        extra_patterns = [
            (r'[^\?\&]*/static/(.*)', BokehStaticFileHandler, {'path': static_path})
        ]
        extra_patterns.extend(
            (p, h, handler_kwargs) for p, h in web_handlers.items())

        retrial = 5
        while True:
            try:
                # Choose a fresh port on each attempt unless one was
                # configured: a previously picked port may have been taken.
                port = configured_port
                if port is None:
                    port = get_next_port()

                self._web_server = Server(
                    handlers, allow_websocket_origin=['*'],
                    address=host, port=port,
                    extra_patterns=extra_patterns,
                    http_server_kwargs={'max_buffer_size': 2 ** 32},
                )
                self._web_server.start()
            except OSError:  # pragma: no cover
                retrial -= 1
                # A user-specified port cannot be substituted, so fail fast;
                # otherwise give up only after exhausting all attempts.
                if configured_port is not None or retrial == 0:
                    raise
            else:
                logger.info('Mars UI started at %s:%d', host, port)
                break


async def start(config: dict, address: str = None):
    """Create a ``WebActor`` from the ``web`` section of ``config`` and start it.

    ``address`` is forwarded to ``mo.create_actor``. A missing ``web`` section
    is treated as an empty dict.
    """
    web_config = config.get('web', {})
    web_actor_ref = await mo.create_actor(
        WebActor, config=web_config, address=address)
    await web_actor_ref.start()
